import numpy as np
import matplotlib.pyplot as plt
import gzip


def read_idx3(filename):
    """
    读取gz格式的数据集图像部分，并返回

    :param filename: extension name of the file is '.gz'
    :return: images data, shape -> num, rows, cols
    """
    with gzip.open(filename) as fo:
        print('Reading images...')
        buf = fo.read()

        offset = 0  # 偏移量
        # 首先获取的是这个数据集的头部数据，通常是元数据。
        #   '>i'  表示顺序读取，并且数据类型为整数
        #   4  读4个单位
        #   offset 偏移量
        # 返回的是一个数组，赋值给header
        header = np.frombuffer(buf, dtype='>i', count=4, offset=offset)
        print(header)
        magic_number, num_images, num_rows, num_cols = header
        # magic number 即幻数，意义不明，只是读取时需要占位所以声明了
        print("\tmagic number: {}, number of images: {}, number of rows: {}, number of columns: {}" \
              .format(magic_number, num_images, num_rows, num_cols))
        # 计算偏移量，以读取后续的内容
        # size = 数组长度
        # itemsize = 每个元素的大小
        # 因此乘起来就是跳过header的内容，读后续的内容
        offset += header.size * header.itemsize
        # 读取真正的数据。>B 表示是二进制数据
        data = np.frombuffer(buf, '>B', num_images * num_rows * num_cols, offset).reshape(
            (num_images, num_rows, num_cols))
        # .reshape 表示按传入的参数重新构造这个数组

        return data, num_images


def read_idx1(filename):
    """
    读取gz格式的数据集标签部分，并返回

    :param filename: extension name of the file is '.gz'
    :return: labels
    """
    with gzip.open(filename) as fo:
        print('Reading labels...')
        buf = fo.read()

        offset = 0
        header = np.frombuffer(buf, '>i', 2, offset)
        magic_number, num_labels = header
        print("\tmagic number: {}, number of labels: {}" \
              .format(magic_number, num_labels))

        offset += header.size * header.itemsize

        data = np.frombuffer(buf, '>B', num_labels, offset)
        return data